
[Fix] Mixed Precision (float16) numerical instability in GroupNormalization with small epsilon#22589

Open
ChiragSW wants to merge 6 commits into keras-team:master from ChiragSW:issue#22586

Conversation

@ChiragSW
Contributor

@ChiragSW ChiragSW commented Mar 28, 2026

Root Cause

With autocast=True (the default), inputs were cast to float16 before reaching call(). Values exceeding float16's maximum (65504) overflowed to inf, causing NaN propagation through the normalization math. The existing internal float32 upcast couldn't recover the already-lost values.
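The failure mode can be reproduced with plain NumPy, independent of Keras; this is a minimal sketch of the root cause described above (shapes, values, and epsilon are illustrative):

```python
import numpy as np

# Minimal NumPy sketch of the failure mode (not the Keras code itself).
# 70000 exceeds float16's maximum (65504), so the cast overflows to inf,
# and inf - inf produces NaN in the normalization arithmetic.
x = np.full((1, 4), 70000.0, dtype="float32")
x16 = x.astype("float16")   # what autocast does before call()
print(np.isinf(x16).all())  # True: every value overflowed

with np.errstate(invalid="ignore"):
    mean = x16.mean()                 # inf
    var = ((x16 - mean) ** 2).mean()  # inf - inf -> nan
    normalized = (x16 - mean) / np.sqrt(var + np.float16(1e-3))
print(np.isnan(normalized).any())  # True: NaNs propagate to the output
```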

Fix

  • self.autocast = False keeps inputs in their original dtype (float32), preventing overflow
  • autocast=False on the gamma/beta weights stores the weights in float32 for precision
  • ops.cast(outputs, self.compute_dtype) returns a proper float16 output for mixed precision
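The fix pattern above can be sketched in plain NumPy (not the actual layer code): do the normalization arithmetic in float32 and cast only the final result to the compute dtype.

```python
import numpy as np

# NumPy sketch of the fix: keep inputs and normalization math in float32,
# then cast only the final output down to float16 (the mixed-precision
# compute dtype), mirroring ops.cast(outputs, self.compute_dtype).
x = np.full((1, 4), 70000.0, dtype="float32")  # never cast down internally

mean = x.mean()
var = ((x - mean) ** 2).mean()
normalized = (x - mean) / np.sqrt(var + 1e-3)  # finite: no overflow

out = normalized.astype("float16")
print(out.dtype, np.isnan(out).any())  # float16 False
```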

I also added regression tests:

  1. test_large_value_within_autocast_scope : verifies that weights aren't corrupted by autocast (the same test exists in BatchNormalization and LayerNormalization)
  2. test_mixed_float16_large_inputs : catches the actual NaN bug
  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the GroupNormalization layer to improve numerical stability during mixed precision training. It disables automatic casting for the layer and its weights (gamma and beta) and adds an explicit cast to compute_dtype at the end of the call method. New test cases are included to ensure that large input values do not result in NaNs when running within a float16 autocast scope. I have no feedback to provide.

@codecov-commenter

codecov-commenter commented Mar 28, 2026

Codecov Report

❌ Patch coverage is 18.18182% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 24.73%. Comparing base (28a83c5) to head (ddb80ff).
⚠️ Report is 26 commits behind head on master.

Files with missing lines Patch % Lines
...as/src/layers/normalization/group_normalization.py 18.18% 9 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (28a83c5) and HEAD (ddb80ff). Click for more details.

HEAD has 10 uploads less than BASE
Flag BASE (28a83c5) HEAD (ddb80ff)
keras 6 1
keras-jax 2 1
keras-numpy 1 0
keras-openvino 1 0
keras-torch 1 0
keras-tensorflow 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #22589       +/-   ##
===========================================
- Coverage   83.28%   24.73%   -58.56%     
===========================================
  Files         596      596               
  Lines       68089    68254      +165     
  Branches    10607    10668       +61     
===========================================
- Hits        56711    16884    -39827     
- Misses       8634    50456    +41822     
+ Partials     2744      914     -1830     
Flag Coverage Δ
keras 24.73% <18.18%> (-58.37%) ⬇️
keras-jax 24.73% <18.18%> (-34.94%) ⬇️
keras-numpy ?
keras-openvino ?
keras-tensorflow ?
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


@ChiragSW
Contributor Author

@hertschuh please review

@keerthanakadiri keerthanakadiri added the stat:awaiting keras-eng Awaiting response from Keras engineer label Mar 30, 2026
@amadhan882

The fix in #22589 correctly addresses the numerical instability I reported in #22586. Disabling autocast to keep internal computations in float32 is the right approach and aligns with other normalization layers. Thanks for the quick fix @ChiragSW!

):
    super().__init__(**kwargs)
    self.supports_masking = True
    self.autocast = False
Collaborator


We should not hardcode self.autocast = False. The fix is indeed to do this:

keras.layers.GroupNormalization(groups=8, epsilon=1e-12, autocast=False)

But this should be controlled by users, not hardcoded.

The contract of autocast is to accept lower precision to improve speed, and that option should remain open to people who want it.

Now, we could print a warning if the epsilon is lower than the precision, because such an epsilon is not achievable at that dtype.
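The suggested warning could look roughly like this; a hypothetical standalone helper using NumPy's finfo (the real layer would use its resolved compute dtype):

```python
import warnings
import numpy as np

def warn_if_epsilon_too_small(epsilon, compute_dtype):
    """Hypothetical helper sketching the suggested warning: flag an
    epsilon below the machine epsilon of the compute dtype."""
    eps_limit = float(np.finfo(np.dtype(compute_dtype)).eps)
    if epsilon != 0 and epsilon < eps_limit:
        warnings.warn(
            f"epsilon={epsilon} is below the {compute_dtype} machine "
            f"epsilon ({eps_limit}); it has no effect at that precision."
        )

warn_if_epsilon_too_small(1e-12, "float16")  # warns: 1e-12 << 9.77e-4
warn_if_epsilon_too_small(1e-3, "float32")   # silent: epsilon is effective
```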

@hertschuh hertschuh added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer labels Mar 31, 2026
@ChiragSW
Contributor Author

ChiragSW commented Apr 2, 2026

The suggested changes are done, please review @hertschuh

gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
autocast=True,
Collaborator


Please revert all the changes related to autocast. We don't want any behavior change, for backwards compatibility. It should stay false by default and people should opt in if they want that behavior.

initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint,
autocast=False,
Collaborator


Undo this

initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint,
autocast=False,
Collaborator


Undo this

Comment on lines +169 to +189
if self._epsilon_too_small_warning_issued:
    return
if not self.autocast:
    return
compute_dtype = backend.standardize_dtype(self.compute_dtype)
if compute_dtype not in ("float16", "bfloat16"):
    return
try:
    finfo = ml_dtypes.finfo(compute_dtype)
    min_pos = getattr(finfo, "smallest_subnormal", finfo.tiny)
except Exception:
    return
if self.epsilon != 0 and self.epsilon < float(min_pos):
    warnings.warn(
        "The configured `epsilon` is smaller than what can be "
        f"represented in the layer compute dtype ({compute_dtype}); "
        "it may be rounded to 0 under autocast. Consider increasing "
        "`epsilon` or setting `autocast=False` for this layer.",
        stacklevel=3,
    )
    self._epsilon_too_small_warning_issued = True
Collaborator


Move this into __init__ and remove self._epsilon_too_small_warning_issued.

All of this will work after super().__init__() is called.
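A rough sketch of that restructuring (a toy standalone class, not the actual layer): since __init__ runs exactly once, the one-shot flag becomes unnecessary.

```python
import warnings
import numpy as np

class GroupNormSketch:
    """Toy stand-in for the layer, sketching the suggested structure:
    run the epsilon check once in __init__, so the
    `_epsilon_too_small_warning_issued` flag is no longer needed."""

    def __init__(self, epsilon=1e-3, compute_dtype="float16", autocast=True):
        # In the real layer this would follow super().__init__(**kwargs),
        # after which self.compute_dtype is already resolved.
        self.epsilon = epsilon
        self.compute_dtype = compute_dtype
        self.autocast = autocast
        if autocast and epsilon != 0:
            eps_limit = float(np.finfo(np.dtype(compute_dtype)).eps)
            if epsilon < eps_limit:
                warnings.warn(
                    f"`epsilon` ({epsilon}) is below the machine epsilon "
                    f"of {compute_dtype} ({eps_limit}); consider a larger "
                    "`epsilon` or `autocast=False`."
                )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    GroupNormSketch(epsilon=1e-12)
print(len(caught))  # 1: warned exactly once, at construction time
```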

Comment on lines +174 to +175
if compute_dtype not in ("float16", "bfloat16"):
return
Collaborator


Remove this. The test on finfo below should apply if you're using float64 vs float32 for instance.

    return
try:
    finfo = ml_dtypes.finfo(compute_dtype)
    min_pos = getattr(finfo, "smallest_subnormal", finfo.tiny)
Collaborator


smallest_subnormal and tiny are way too small, it would never catch the issue in the bug.

I think you should use finfo.eps instead.
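To quantify that point with plain NumPy (the dtype constants are exact properties of float16): tiny and smallest_subnormal are orders of magnitude below the resolution that matters when epsilon is added to a variance of order one.

```python
import numpy as np

# float16 thresholds (NumPy >= 1.22 exposes smallest_subnormal):
info = np.finfo(np.float16)
print(float(info.eps))                 # 0.0009765625 (2**-10)
print(float(info.tiny))                # 6.103515625e-05 (2**-14)
print(float(info.smallest_subnormal))  # ~5.96e-08 (2**-24)

# epsilon=1e-5 is representable in float16 (above the subnormal
# threshold), so a smallest_subnormal test would never fire -- yet it
# still vanishes when added to a variance of order 1:
print(np.float16(1.0) + np.float16(1e-5) == np.float16(1.0))  # True
```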

Comment on lines +223 to +235
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    layer = layers.GroupNormalization(
        groups=2,
        axis=-1,
        scale=False,
        center=False,
        epsilon=1e-12,
        dtype="mixed_float16",
        autocast=False,
    )
    _ = layer(x)
self.assertFalse(any("epsilon" in str(x.message) for x in w))
Collaborator


What's the difference between this test and the previous one line 212?

Comment on lines +218 to +219
epsilon=1e-12,
dtype="mixed_float16",
Collaborator


This combination should work after you use finfo.eps:

epsilon=1e-4,
dtype="mixed_bfloat16",

Comment on lines +184 to +207
def test_large_value_within_autocast_scope(self):
    layer = layers.GroupNormalization(groups=2)
    layer.build((1, 4, 4, 4))
    large_value = ops.full(layer.gamma.shape, 70000)
    with backend.AutocastScope("float16"):
        layer.gamma.assign(large_value)
    self.assertAllClose(layer.gamma.value, large_value)

def test_mixed_float16_large_inputs(self):
    layer = layers.GroupNormalization(
        groups=2, axis=-1, scale=False, center=False
    )
    x = np.full((1, 4, 4), 70000.0, dtype="float32")
    with backend.AutocastScope("float16"):
        output = layer(x)
    output = backend.convert_to_numpy(output)
    self.assertFalse(np.any(np.isnan(output)))

def test_autocast_is_user_controllable(self):
    layer_default = layers.GroupNormalization()
    self.assertTrue(layer_default.autocast)

    layer_no_autocast = layers.GroupNormalization(autocast=False)
    self.assertFalse(layer_no_autocast.autocast)
Collaborator


Remove these, I don't think these tests exercise anything new.
